MHANet: A hybrid attention mechanism for retinal diseases classification

With the increase of patients with retinopathy, retinopathy recognition has become a research hotspot. In this article, we describe the etiology and symptoms of three kinds of retinal diseases, including drusen(DRUSEN), choroidal neovascularization(CNV) and diabetic macular edema(DME). In addition, we also propose a hybrid attention mechanism to classify and recognize different types of retinopathy images. In particular, the hybrid attention mechanism proposed in this paper includes parallel spatial attention mechanism and channel attention mechanism. It can extract the key features in the channel dimension and spatial dimension of retinopathy images, and reduce the negative impact of background information on classification results. The experimental results show that the hybrid attention mechanism proposed in this paper can better assist the network to focus on extracting thr fetures of the retinopathy area and enhance the adaptability to the differences of different data sets. Finally, the hybrid attention mechanism achieved 96.5% and 99.76% classification accuracy on two public OCT data sets of retinopathy, respectively.


Introduction
Diabetic retinopathy (DR) is a common blinding disease, mostly in diabetic patients [1]. Changes in blood components of diabetic patients cause dysfunction of vascular endothelial cells, which leads to impaired retinal barrier in diabetic patients. According to a study by the International Diabetes Federation (IDF), patients with diabetes over 10 years have a 60.00% chance of developing retinopathy. Periodic examinations of diabetic patients and effective control of retinal diseases can prevent patients from transient or permanent blindness [2]. Diabetic macular edema (DME) is one of the main symptoms of diabetic retinopathy (DR) [3]. The main reason for this disease is the long-term hyperglycemia state of diabetic patients, which leads to increased vascular permeability of the retina and choroid of the eye, and the water molecules and some protein components in the blood can permeate through the damaged vascular walls and form edema in the macular area.
Another common retinopathy is age-related macular degeneration (AMD) [4,5]. At present, the exact cause of the disease is unknown. It may be related to genetic factors, environmental effects, chronic retinal light injury, nutritional disorders, metabolic disorders and so on. However, aging and degeneration are important factors that cause age-related macular a1111111111 a1111111111 a1111111111 a1111111111 a1111111111 degeneration. According to clinical manifestations and pathological changes, drusen(DRU-SEN) and choroidal neovascularization (CNV) are one of the main symptoms of AMD.
Among them, the early symptoms of CNV are subretinal neovascularization, but the symptoms are not obvious. In the middle stage, the blood vessels will gradually expand to a certain extent, causing leakage, rupture and bleeding, which can lead to vision loss. In the later stage, with the aggravation of vascular infiltration and bleeding, the patient's vision was severely damaged, which caused the patient's permanent visual impairment. DRUSEN is often caused by metabolic disorders of the pigment epithelium. In the early stage, the patient's visual function may not be impaired. However, as the condition worsens, some patients have enlarged physiological blind spots in their eyes. If the situation is more serious, it may cause the patient's field of view to narrow or vision loss. AMD has become the most important blinding disease in western developed countries [4,5]. As the macular area is the most important part of the retina and the most sensitive part of vision, 80.00% of human vision comes from the macular area. Once a lesion such as edema occurs in the macular area, the impact on vision is very serious. Therefore, regular checking for diabetic patients and effective control of the disease in early stage can prevent temporary or permanent blindness in [2].
With the rapid increase in the number of patients with retinopathy, hospitals need a lot of manpower and material resources to face complex situations. In order to reduce the burden on hospitals and doctors, the auxiliary diagnosis and analysis technology of retinopathy is an urgent technology for doctors. OCT is a new tomography technology developed in recent years [6][7][8][9]. Compared with the traditional fundus detection technology, OCT can perform non-contact and non-invasive tomography on the microstructure of living eye tissue [10]. Therefore, this paper uses the OCT retinal fundus detection image datasets for our research.
With the wide application of neural networks in many fields [11], medical image classification based on deep learning has also become a research hotspot. Many scholars have tried to use neural networks to identify retinopathy, and achieved good results. In order to better lock foreground information of the picture, image denoisingis often used in image processing [12,13]. In addition, the attention mechanism can also lock foreground information of the image very well, so this article explores the role of the attention mechanism in the recognition of retinal diseases. Attention mechanism was originally used in machine translation [14], and now it has become an important concept in image recognition. When recognizing an image, it is similar to the human visual system. It can assist the network to pay more attention to the key information in the image. In order to better identify retinal diseases, this paper proposes a new attention mechanism, which can better improve the accuracy of retinal disease recognition.
At present, the attention mechanism in image recognition is mainly divided into channel attention mechanism and spatial attention mechanism. The channel attention mechanism is used to distinguish the relationship between the feature map channels and the spatial attention mechanism is used to distinguish the relationship between the elements in the feature map. Because the lesion area in OCT images of retinopathy has the following characteristics: (1) The lesion features are concentrated in the local area. (2) The characteristics of the lesions are not much different from those of normal tissues. Therefore, this paper proposes a multi-branch hybrid attention network (MHANet) to lock the lesion area of the picture. Specifically, the channel attention mechanism is mainly used to highlight the channel elements with the most abundant lesion features; Spatial attention mechanism is mainly used to highlight the elements of the area where the lesion features are located. Experimental results show that this algorithm has certain advantages compared with previous mainstream algorithms. In addition, the visualization results show that compared with the mainstream model using channel attention mechanism alone, the hybrid attention mechanism proposed in this paper can more accurately lock the location of the lesion area in retinal OCT images.

Related work
With the development of machine learning and deep learning, medical image analysis based on deep learning has become a research hotspot for many scholars.
In 2014, Srinivasan et al. [15] applied machine learning methods to image analysis of retinal diseases. They used the support vector machine (SVM) as the classification tool and the HOG feature of the OCT images as the classification basis to classify and recognize the three types of retinopathy images. They constructed a multi class classifier using three linear SVM, and then classified NORMAL and AMD, NORMAL and DME, AMD and DME. Finally, AMD, DME and NORMAL achieved 100.00%, 100.00% and 86.67% classification accuracy respectively. This experimental method not only increases the complexity of the experiment but also does not fundamentally distinguish the three types of pictures at the same time. In addition, in order to promote research of ophthalmic diseases, they also created a public OCT dataset for the majority of scholars to use.
In 2016, Wang et al. [16] used same data set as Srinivasan et al. [15] to further deepen research on retinopathy. They first used two feature selection algorithms CFS (Correlationbased Feature Subset) and CSE (Classifier Subset Evaluator) to find a subset of Linear Configuration Pattern (LCP) features [17,18], and then used the Sequence Minimization Optimization (SMO) algorithm as a classification tool to recognize and classify the selected image features. Unlike Srinivasan et al., they combine the local and global features of the picture. The experimental results show that Wang et al achieved 97.80%, 94.00% and 99.60% accuracy on AMD, DME and NORMAL respectively. Although the feature selection algorithm can reduce the redundant or useless features in the picture, some key features will be lost in the process of feature selection.
In 2017, Rasti et al. [13] used a multi-scale neural network to analyze retinopathy images. They used parallel convolution kernels of different sizes to extract the features of different regions of the feature map, and then set up a gated network to assign corresponding coefficients to the feature maps extracted by different convolution kernels to adjust the relationship between different feature maps. In addition, they also created a new OCT retinal disease dataset. For the dataset created by Srinivasan et al., the overall classification accuracy of AMD, DME and NORMAL is 99.39%. For their own datasets, the accuracy rates of AMD, DME, and NORMAL are 95.67%, 98.22% and 96.67%, respectively. Although the gated network can reconcile the relationship between the feature map and the feature map, it does not highlight the importance of the channel or region of the feature map itself.
In 2017, Karri et al. [19] used a fine-tuning pre-training model (GoogleNet) to study the OCT images of retinopathy provided by Srinivasan et al. They first changed the fully connection layer of the last layer of the network to 3 outputs, then called the parameters pre-trained on ImageNet in advance as the initialization parameters of the network, and finally used the retinopathy dataset for network training. The accuracy of the final experiment on AMD, DME and NORMAL were 89.00%, 86.00% and 99.00% respectively. Although the pre training model can make the network quickly reach the convergence state, the final classification effect needs to be improved.
In 2019, Feng et al. [20] conducted a classification and recognition study on four types of eye retina images. They first loaded the pre training parameters of VGG16 trained in Imagenet, and then changed the full connection layer of the last layer to four outputs. Finally, the overall accuracy of 97.80% was achieved on CNV, DME, DRUSEN and NORMAL. Although the migration learning method proposed in literature [19,20] can reduce the training time of the network and alleviate the network's excessive dependence on the dataset, it reduces the generalization of the network to different datasets. For OCT image classification research, Wang et al. [21] compared VGG16 [22], VGG19 [22], Inception-v3 [23] with CliqueNet [24], DPN92 [25], DenseNet121 [26], ResNet50 [27], ResNext101 [28]. Experiments show that the network with jump connection operation can reduce the loss of effective information in the process of feature extraction and significantly improve the classification effect.
With the development of a deep learning network, current scholars are pursuing the lightness and convenience of the network. Ding X et al. [29] proposed the RepVGG network. RepVGG adds a jump connection operation to the network based on the VGG network. On the premise of ensuring training speed and training accuracy, the accuracy of the RepVGG network on ImageNet dataset can reach more than 80%, therefore, this article uses the RepVGG network to study the retina OCT images.
In addition, the outstanding achievements of attention mechanism in the field of image have also attracted the interest of many scholars. In 2019, Jie et al. [30] proposed SENet network architecture. The core idea of the network is to assign the corresponding channel coefficients to the multi-channel feature maps, so as to highlight the relationship between the channels of the feature maps. Soon after the SENet network was proposed, its enhanced network architecture SKNet [31] was born. Compared with SEnet, SKNet studies the channel-to-channel relationship between two feature maps and has achieved better results on the ImageNet dataset.
According to the characteristics of lesion areas in OCT retinopathy images, we propose a hybrid attention mechanism, which is composed of a channel attention mechanism and a spatial attention mechanism in parallel. It mainly has the following characteristics: 1. By calculating the channel coefficients of the feature maps extracted by the convolution kernels with different expansion rates, the network can automatically identify the importance of the corresponding channel elements between different feature maps.
2. By calculating the spatial coefficients of the feature maps extracted by the convolution kernels of different sizes, the network can automatically identify the importance of regional elements in different feature maps.

Datasets
Considering the influence of the number of datasets on the experiment, two public datasets are used in this paper. Among them, Dataset1 was provided by Rasti et al., obtained at the Noor Eye Hospital in Tehran using Spectralis SD-OCT imaging [13]. Dataset2 was provided by Kermany et al, and obtained by using Spectralis SD-OCT imaging [32]. The original train dataset of Dataset2 is 83484 OCT images from 4686 patients. There are 37,205 choroidal neovascularization (CNV), 8616 drusen (DRUSEN), 11348 diabetic macular edema (DME), and 26315 normal(NORMAL). The test dataset includes 250 pictures of CNV, DRUSEN, DME and NORMAL from 633 patients. In the training stage, due to the huge difference in the number of four types of pictures, the model focuses on learning large sample data, so that small sample data is not fully learned, and generalization ability of the network is reduced. Therefore, we processed Dataset2 as follows: 1: Using 8616 DRUSEN images as the standard, 8616 images were randomly selected from CNV, DME and NORMAL. 2: The number of each picture is 8616, we divide it into train dataset and test dataset according to the ratio of 8:2.
Dataset1 has a total of 4254 images, of which 1104 are DME, 1565 are AMD, and 1585 are NORMAL. Similarly, we divide the pictures into train dataset and test dataset according to the ratio of 8:2. Table 1 shows the main statistics of the two datasets in this article.
Since the quality and size of the two public data sets are different, we used these two data sets for training and evaluation respectively. Experimental data shows that our network model has good generalization ability. Fig 1 show examples of datasets.

Network architecture
In this section, we will introduce our proposed Multi-branch hybrid attention network (MHA-Net) in detail. As shown in Fig 2, MHANet consists of two parts: channel attention module and spatial attention module. The main function of the channel attention module is to assign channel coefficients to different feature maps to identify the channel-to-channel correlation between the feature maps. The main function of the spatial attention module is to assign spatial coefficients to different feature maps to distinguish the importance of location information between the feature maps.

PLOS ONE
As shown in Fig 2, U 2 R C×H×W is the feature map, C is the number of channels,W and H represent length and width of the feature map, respectively. First, we performed three convolution operations on the input feature map: F1: U ! Ua 2 R C×H×W , F2: U ! Ub 2 R C×H×W , and F3: U ! Uc 2 R C×H×W . Among them, F1 means using a convolution operation with a convolution kernel size of 3 × 3 and an expansion rate of 2. F2 represents a convolution operation with a convolution kernel size of 3 × 3 and an expansion rate of 1. F3 represents a convolution operation with a convolution kernel size of 5 × 5 and an expansion rate of 1. Secondly, the element summation operation is carried out, and the feature maps from different branches are fused to obtain Ta 2 R C×H×W and Tb 2 R C×H×W respectively.
Channel attention model of MHANet. In order to identify the channel-to-channel relationship between different feature maps, we use feature maps Ta 2 R C×H×W to obtain channel attention coefficients. The specific operations are shown in the following steps: In order to reduce redundant information in the feature map and improve fault tolerance of the model, we first perform the maximum pooling and average pooling operations on the feature map Ta in the channel dimension to obtain two new feature maps {A, E} 2 R C×1×1 . Then we add and fuse the two new feature maps so that we can get foreground information of the picture as much as possible.
The maximum pooling and average pooling in the channel dimension are to take a maximum value and an average value from the feature map element values of each channel respectively; A i 2 R 1×1 and E i 2 R 1×1 are obtained from the i-th channel feature map of Ta through maximum pooling and average pooling respectively.
In addition, in order to assign channel coefficients to the feature maps Ua and Ub, we perform a convolution operation on the feature map Pa to achieve the channel dimension upgrade to obtain the feature map Pb 2 R 2C×1×1 , in which the size of the convolution kernel is 1 × 1. Then the feature map Pb is divided equally in the channel dimension to obtain the
Among them, Pb i 2 R 1×1 and Pb i+C 2 R 1×1 represent the element values in the i and i + C channels of the feature map Pb respectively; K i 2 R 1×1 and L i 2 R 1×1 represent the element values in the i-th channel of the feature maps K and L respectively.
Finally, we perform a Softmax operation on K and L to obtain the channel attention coefficients A tta 2 R C×1×1 and A ttb 2R C×1×1 , and then we multiply Atta and Attb with Ua and Ub respectively, and the two feature maps obtained after the multiplication are added and fused to obtain the feature map Va 2 R C×H×W .
Atta i 2 R 1×1 and Att b i 2 R 1×1 respectively represent the corresponding coefficients obtained from the element values K i and L i , and Atta i + Attb i = 1; Ua i 2 R H×W and Ub i 2 R H×W respectively represent all the element values of the i-th channel of the feature maps Ua and Ub.
Position attention model of MHANet. In order to enable the network to distinguish the importance of elements in different feature maps, we use the feature map Tb 2 R C×H×W to calculate the spatial attention coefficient. The specific operation includes the following steps.
In order to assign spatial coefficients to the feature maps Ua and Ub, we perform a convolution operation on the feature map Tb to achieve the channel dimension upgrade to obtain the feature map M 2 R 2C×H×W , in which the size of the convolution kernel is 1 × 1. Then the feature map M is divided equally in the channel dimension to obtain the feature maps N 2 R C×H×W and Z 2 R C×H×W . Finally, we use the Sigmoid activation function on the feature maps N and Z to obtain the spatial attention coefficients Attc 2 R C×H×W and Attd 2 R C×H×W , and then multiply the spatial attention coefficients Attc and Attd with the feature maps Ub and Uc respectively, and the two feature maps obtained after the multiplication are added and fused to obtain the feature map N (j,h,w) and Z (j,h,w) respectively represent the element at the position (h, w)� in the j-th channel of the feature maps N and Z; Attc (j,h,w) and Attd (j,h,w) are obtained from the corresponding N (j,h,w) and Z (j,h,w) respectively. The specific formula is shown above.
Finally, we concatenate the feature maps V1 and V2 to obtain the feature map V 2 R 2C×H×W , and then use the convolution operation with the convolution kernel size of 1 × 1 to reduce the channel dimension of the feature map V to obtain the feature map O 2 R C×H×W . It is worth noting that the final output feature map highlights the foreground information of the picture in both the channel dimension and the spatial dimension.
In addition, we provide pseudo code to introduce the overall flow of the experiment, as shown in Table 2.

Experimental details
The experiments were run based on the Python 3.6, Torch 1.6.0. In the process of network training, after many experiments, it is found that the learning rate is 0.001, the batch size is set to 224, and the network can achieve the optimal effect. The loss function is 'cross-entropy loss', and the optimization method is the SGD gradient descent method.

Statistical analysis method
Because different evaluation indexes have different evaluation significance to the experimental results, three common evaluation indexes are used to evaluate the experimental results in this paper: accuracy rate, precision rate, and recall rate. Also, to make the experiment more concise and understandable, we use the mixed matrix to show the classification results of each kind of data. Accuracy, precision, and recall are defined as follows:

PLOS ONE
TP represents the number of samples that are actually positive and predicted to be positive; TN represents the number of samples that are actually negative and predicted to be negative; FP represents the number of samples that are actually negative and predicted to be positive; FN represents the number of samples that are actually positive and predicted to be negative.

Experimental results
In order to better illustrate the positive effect of the hybrid attention mechanism on the recognition of retinopathy, we give the classification accuracy of the entire dataset and the classification index score of each type of data. From the accuracy score of each model on Dataset1, it can be seen that the attention mechanism plays an active role in the three types of image recognition of AMD, DME, and NOR-MAL. In particular, our proposed MHANet network achieved the best results in retinal image classification, and the overall accuracy reached 99.76%, as shown in Table 3.  Table 4 shows the Precision, Recall, and F1 of each type of picture. It can be seen from the table that the F1 values obtained by ResNet50, Res2Net50, SENet and SKNet on AMD and NORMAL are significantly better than their F1 values obtained on DME, which indicates that generalization ability of the above models has certain defects, and hybrid attention proposed in this article effectively alleviates this problem.

PLOS ONE
For the Dataset2, the four types of pictures are divided according to the ratio of 1:1:1:1, so the network will not have the over fitting problem caused by the large difference in the number of training pictures. The MHANet proposed in this paper adopts both spatial attention mechanism and channel domain attention mechanism, and achieves the best results. See Table 5 for specific experimental data.

PLOS ONE
As shown in Table 6, the MHANet network proposed in this paper not only achieved the best results in accuracy, but also has certain improvements in Precision, Recall, and F1 compared to other networks. This is because the MHANet network can not only automatically highlight the importance of channels, but also distinguish the importance of element values in the entire feature map. In order to better prove the feasibility of the innovation in this paper, this paper further demonstrates the experimental results of Dataset1 and Dataset2 through a confusion matrix. The results are shown in Figs 3 and 4.
As shown in Fig 3, we give the confusion matrix obtained by each model based on Dataset1, where 0.0 represents AMD, 1.0 represents DME and 2.0 represents NORMAL. From each confusion matrix, we can know that other models will more or less misclassify these three types of pictures. MHANet proposed in this paper can completely and correctly classify DME and NORMAL, and minimize the error of AMD as DME. Among them, 0.0 represents CNV, 1.0 represents DRUSEN, 2.0 represents DME, and 3.0 represents NORMAL. From the perspective of the model, the MHANet model has the largest number of correct image classifications among the five models. Among them, MHANet has 298 correctly classified pictures more than VGG16, 142 correctly classified pictures more than RepVGG, 80 correctly classified pictures more than ResNet50, 72 correctly classified pictures more than Res2Net50, 88 correctly classified pictures more than SENet and 77 correctly classified pictures more than SKNet. This fully proves that this model has achieved the best effect in the task of retinal disease recognition. In addition, other models are prone to mispredict sample 2 as sample 0 and mispredict sample 3 as sample 2. For these two kinds of images that are easy to distinguish errors, MHANet has the best classification effect.

PLOS ONE
In order to analyze the generalization ability of the model, we give the accuracy curve during model training and the accuracy curve during testing, as shown in Figs 5-8. Among them, Figs 5 and 6 are the accuracy curves during training and testing based on Dataset1. It can be seen from the figure that VGG16 and RepVGG can get very high scores during training, but the test results are not satisfactory, indicating that generalization ability of VGG16 and RepVGG on Dataset1 is not strong. ResNet50 and Res2Net50 have achieved good results in both the train stage and the test stage, indicating that ResNet50 and Res2Net50 have better generalization capabilities on Dataset1, but ResNet50 is better than Res2nNet50 on the test set. This is because Res2Net50 has more parameters than ResNet50, which makes it difficult for the network to reach the best state. Similarly, it can be seen from the figure that SENet and SKNet have good generalization ability on Dataset1, but the test results have not been greatly

PLOS ONE
improved, so the network has certain limitations. Finally, we can see that the MHANet network can get the best results in the train stage and the test stage, which proved that MHANet has good convergence and has made breakthrough. Figs 7 and 8 show the accuracy during training and testing based on Dataset2. It can be seen from the figure that compared with Dataset1, RepVGG performed better on Dataset2. In addition, ResNet50, Res2Net50, SENet, SKNet and MHANet achieved convergence in the train and test stages, but MHANet achieved the best experimental results in the testing phase. This once again proved that the MHANet network can get good experimental results on small datasets, and good results can be obtained on large datasets.
As shown in Fig 9, The AUC values obtained by SENet, SKNet and MHANet are all above 0.96 on Dataset 1 and the average AUC values of three types are all above 0.97,which shows

PLOS ONE
that attention mechanism model we trained on Dataset 1 can well balance the precision and recall indicators. Among them, the AUC values of the MHANet model proposed in this paper all reached 1, which also shows that the hybrid attention mechanism can have a better effect in retinal disease classification experiments than using only the channel attention mechanism.
As shown in Fig 10, The AUC values obtained by ResNet, Res2Net, SENet, and SKNet on Dataset 2 and the average AUC values of the three types are not much different, which shows that it is difficult to improve the discrimination ability of the network on Dataset2 using only the channel attention mechanism. The AUC value obtained on Dataset2 and the average AUC value of the three types of MHANet, which is composed of the parallel channel attention mechanism and the spatial attention mechanism, are the highest.
As shown in Figs 11 and 12, we show the focusing ability of different networks. Among them, ResNet50 without attention mechanism has completely different focusing ability on two different data sets. On the Dataset1, the focus range of ResNet50 is wide and the focus center does not accurately fall on the foreground information part. On the Dataset2, ResNet50 shrinks the focus range but the focus center still does not accurately fall on the foreground information part. SENet and SKNet adopt different channel attention mechanisms that cause their focusing abilities to be different. On the two datasets, the focusing ability of SKNet is obviously better than that of SENet. However, the focus centers of SENet and SKNet are still not accurately gathered in foreground information. The MHANet network we proposed has a small focus range on two different datasets and the focus center falls on the foreground information part.

Conclusion
According to the distribution of lesion features in retinopathy images, we propose a hybrid attention mechanism to help the network lock foreground information in the image more accurately. The MHANet network module is composed of a parallel channel attention mechanism and a spatial attention mechanism. The channel attention mechanism assigns different channel coefficients to each channel feature map in the channel dimension, so as to highlight the channel feature map with the most abundant lesion features. The spatial attention mechanism assigns corresponding position coefficient to each element in the spatial dimension, so as to highlight the importance of elements in the lesion area. This hybrid attention mechanism can help the network focus on the lesion area in both the spatial dimension and the channel dimension. In this paper, the feasibility of the MHANet module is verified on two public retinopathy datasets. The experimental results show that the MHANet module can not only improve the accuracy of network classification and recognition, but also improve the generalization ability of the network.